package com.cunyu.dao;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Resource;
import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.BeanPropertyRowMapper;
import org.springframework.jdbc.core.ColumnMapRowMapper;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SingleColumnRowMapper;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;

import com.cunyu.util.ReflectUtil;

/**
 * Base DAO with insert / update / delete and query helpers, built on Spring's
 * {@link JdbcTemplate}.
 *
 * <p>Most helpers log the SQL and its bind values at INFO level before running it.
 * {@link #insertDB(Object, String)} and {@link #updateDB(Object, String, String)}
 * concatenate table names, column names and the WHERE clause into the SQL text
 * unescaped, so those must never come from untrusted input.
 *
 * <p>NOTE(review): the original header also mentioned {@code NamedParameterJdbcTemplate},
 * presumably as a possible alternative base class — confirm.
 */
public abstract class BaseDao extends JdbcTemplate {

	/** Generated-key column that {@link #insert(String, Object...)} looks up first. */
	private static final String ID_COLUMN = "id";

	/**
	 * Maps a row to a {@link LinkedHashMap} keyed by column label, so entries follow
	 * the SELECT-list order and honour column aliases.
	 */
	private static final RowMapper<Map<String, Object>> ORDERED_ROW_MAPPER =
			new RowMapper<Map<String, Object>>() {
		@Override
		public Map<String, Object> mapRow(ResultSet rs, int rowNum) throws SQLException {
			ResultSetMetaData md = rs.getMetaData();
			int columnCount = md.getColumnCount();
			Map<String, Object> row = new LinkedHashMap<String, Object>();
			for (int i = 1; i <= columnCount; i++) {
				row.put(md.getColumnLabel(i), rs.getObject(i));
			}
			return row;
		}
	};

	// NOTE(review): Spring's JdbcAccessor appears to declare its own protected `logger`;
	// this field shadows it for this class and its subclasses — confirm before removing.
	protected Log logger = LogFactory.getLog(this.getClass());

	/**
	 * Injects the {@link DataSource} (via {@code @Resource}) and hands it to
	 * {@link JdbcTemplate}.
	 *
	 * @param dataSource the DataSource all queries run against
	 */
	@Resource
	@Override
	public void setDataSource(DataSource dataSource) {
		super.setDataSource(dataSource);
	}

	/**
	 * Opens a raw {@link Connection} straight from the configured DataSource, for
	 * extensions. The caller must close it.
	 *
	 * @return a new connection
	 * @throws SQLException if the DataSource cannot supply a connection
	 * @deprecated the connection bypasses JdbcTemplate's connection handling, so it
	 *             does not take part in Spring-managed transactions; prefer the
	 *             JdbcTemplate-based helpers of this class.
	 */
	@Deprecated
	public Connection getConn() throws SQLException {
		return super.getDataSource().getConnection();
	}

	/**
	 * Runs a query and returns each row as an insertion-ordered map keyed by column
	 * label. The map keys therefore match the SELECT list exactly, which is useful
	 * for Excel export.
	 *
	 * <p>Failures are logged and yield an empty list, as before. The query now goes
	 * through JdbcTemplate, so the statement, result set and connection are always
	 * closed. The old implementation leaked the first two.
	 *
	 * @param sql SQL with {@code ?} placeholders
	 * @param obj bind values in placeholder order; may be {@code null}
	 * @return one ordered map per row; never {@code null}
	 */
	public List<Map<String, Object>> getRecord(String sql, Object... obj) {
		List<Map<String, Object>> rows = query(sql, ORDERED_ROW_MAPPER, obj);
		return rows != null ? rows : new ArrayList<Map<String, Object>>();
	}

	/**
	 * Executes an INSERT / UPDATE / DELETE statement.
	 *
	 * @param sql  SQL with {@code ?} placeholders
	 * @param args bind values in placeholder order
	 * @return the number of affected rows
	 */
	public int exec(String sql, Object... args) {
		logSql(sql, args);
		return super.update(sql, args);
	}

	/**
	 * Executes a single-row INSERT and returns the generated key.
	 *
	 * <p>The key named {@code "id"} is preferred. If it is absent and the driver
	 * returned exactly one generated column (for example MySQL's
	 * {@code GENERATED_KEY}), that column's value is used instead.
	 *
	 * @param sql  INSERT SQL with {@code ?} placeholders
	 * @param args bind values in placeholder order
	 * @return the generated key narrowed to {@code int} (wider keys are truncated),
	 *         or 0 if none was returned
	 */
	public int insert(final String sql, final Object... args) {
		logSql(sql, args);
		KeyHolder keyHolder = new GeneratedKeyHolder();
		super.update(new PreparedStatementCreator() {
			@Override
			public PreparedStatement createPreparedStatement(Connection con) throws SQLException {
				PreparedStatement ps = con.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
				bindInputStatement(ps, args);
				return ps;
			}
		}, keyHolder);
		return extractGeneratedId(keyHolder.getKeys());
	}

	/**
	 * Executes the SQL and parameters carried by {@code query}.
	 *
	 * @param query the statement and its bind values
	 * @return the number of affected rows
	 */
	public int exec(Query query) {
		return exec(query.getSql(), query.getParams());
	}

	/**
	 * Executes {@code sql} once per element of {@code obj}, as one JDBC batch. Each
	 * element is bound as the statement's single parameter, for example
	 * {@code batchUpdate("DELETE FROM t WHERE id = ?", 1, 2, 3)}.
	 *
	 * <p>The connection is obtained through JdbcTemplate rather than {@link #getConn()},
	 * so the batch participates in Spring-managed transactions. Failures are thrown as
	 * Spring's unchecked {@code DataAccessException}.
	 *
	 * @param sql SQL with exactly one {@code ?} placeholder
	 * @param obj one bind value per batch entry; may be {@code null}
	 * @return update counts per entry; empty when {@code obj} is null or empty
	 */
	public int[] batchUpdate(String sql, Object... obj) {
		logSql(sql, obj);
		if (obj == null || obj.length == 0) {
			return new int[0];
		}
		List<Object[]> batchArgs = new ArrayList<Object[]>(obj.length);
		for (Object value : obj) {
			// One-element row per value: same single-parameter binding as before.
			batchArgs.add(new Object[] { value });
		}
		return super.batchUpdate(sql, batchArgs);
	}

	/** Binds {@code obj} to {@code statement} positionally (1-based); no-op when null. */
	private void bindInputStatement(PreparedStatement statement, Object... obj) throws SQLException {
		if (obj != null) {
			for (int i = 0; i < obj.length; i++) {
				statement.setObject(i + 1, obj[i]);
			}
		}
	}

	/**
	 * Inserts {@code obj} into {@code tableName}. A Map's keys, or a bean's
	 * properties, become column names, and the values are bound as parameters.
	 *
	 * @param obj       a {@code Map<String, Object>} or a bean
	 * @param tableName target table; concatenated into SQL, so it must be trusted
	 * @return the number of affected rows
	 */
	public int insertDB(Object obj, String tableName) {
		return exec(getInsertBody(toColumnMap(obj), tableName));
	}

	/**
	 * Updates rows of {@code tableName} matching {@code where}. A Map's keys, or a
	 * bean's properties, become the SET columns.
	 *
	 * @param obj       a {@code Map<String, Object>} or a bean
	 * @param tableName target table; concatenated into SQL, so it must be trusted
	 * @param where     raw WHERE clause (without the keyword); concatenated into SQL,
	 *                  so it must be trusted
	 * @return the number of affected rows
	 */
	public int updateDB(Object obj, String tableName, String where) {
		return exec(getUpdateBody(toColumnMap(obj), tableName, where));
	}

	/**
	 * Runs a query, mapping each row with {@code rowMapper}.
	 *
	 * <p>Existing contract: any failure is logged and {@code null} is returned
	 * instead of propagating the exception.
	 *
	 * @return the mapped rows, or {@code null} if the query failed
	 */
	@Override
	public <T> List<T> query(String sql, RowMapper<T> rowMapper, Object... args) {
		logSql(sql, args);
		try {
			return super.query(sql, rowMapper, args);
		} catch (Exception e) {
			logger.error("Query failed: " + sql, e);
			return null;
		}
	}

	/**
	 * Returns every row as a column-name to value map.
	 *
	 * @return the rows, or {@code null} if the query failed
	 */
	public List<Map<String, Object>> getList(String sql, Object... args) {
		return query(sql, new ColumnMapRowMapper(), args);
	}

	/**
	 * Returns the first column of every row, converted to {@code _class}.
	 *
	 * @return the values, or {@code null} if the query failed
	 */
	public <T> List<T> getList(String sql, Class<T> _class, Object... args) {
		return query(sql, new SingleColumnRowMapper<T>(_class), args);
	}

	/**
	 * Returns the first row as a column-name to value map.
	 *
	 * @return the first row, or an empty map when there are no rows or the query failed
	 */
	public Map<String, Object> getMap(String sql, Object... args) {
		List<Map<String, Object>> list = this.getList(sql, args);
		if (list == null || list.isEmpty()) {
			return new HashMap<String, Object>();
		}
		return list.get(0);
	}

	/**
	 * Same as {@link #getList(String, Object...)}, with SQL and parameters taken from {@code query}.
	 *
	 * @return the rows, or {@code null} if the query failed
	 */
	public List<Map<String, Object>> getList(Query query) {
		return query(query.getSql(), new ColumnMapRowMapper(), query.getParams());
	}

	/**
	 * Maps every row onto a new {@code _class} bean by matching column names to properties.
	 *
	 * @return the beans, or {@code null} if the query failed
	 */
	public <T> List<T> getBeanList(String sql, Class<T> _class, Object... args) {
		return query(sql, resultBeanMapper(_class), args);
	}

	/**
	 * Maps the first row onto a new {@code _class} bean.
	 *
	 * @return the bean, or {@code null} when there are no rows or the query failed
	 */
	public <T> T getBean(String sql, Class<T> _class, Object... args) {
		List<T> list = this.getBeanList(sql, _class, args);
		return (list == null || list.isEmpty()) ? null : list.get(0);
	}

	/**
	 * Same as {@link #getBeanList(String, Class, Object...)}, with SQL and parameters taken from {@code query}.
	 *
	 * @return the beans, or {@code null} if the query failed
	 */
	public <T> List<T> getBeanList(Query query, Class<T> _class) {
		return query(query.getSql(), resultBeanMapper(_class), query.getParams());
	}

	/**
	 * Runs a single-row, single-column query and converts the value to {@code requiredType}.
	 * Unlike {@link #query(String, RowMapper, Object...)}, failures propagate.
	 */
	@Override
	public <T> T queryForObject(String sql, Class<T> requiredType, Object... args) {
		logSql(sql, args);
		return super.queryForObject(sql, args, getSingleColumnRowMapper(requiredType));
	}

	/**
	 * Runs a single-value integer query.
	 *
	 * @return the value; 0 when it is SQL NULL; -1 when no row is returned
	 */
	public int getInt(String sql, Object... args) {
		try {
			Integer value = this.queryForObject(sql, Integer.class, args);
			// SQL NULL (e.g. SUM over no rows) maps to 0, checked explicitly instead of catching NPE.
			return value != null ? value : 0;
		} catch (EmptyResultDataAccessException e) {
			return -1;
		}
	}

	/**
	 * Runs a single-value string query. Unlike {@link #getInt}, an empty result is not caught.
	 */
	public String getStr(String sql, Object... args) {
		return this.queryForObject(sql, String.class, args);
	}

	/** Creates a row mapper that populates a new {@code clazz} bean from each row. */
	private <T> BeanPropertyRowMapper<T> resultBeanMapper(Class<T> clazz) {
		return new BeanPropertyRowMapper<T>(clazz);
	}

	/**
	 * MySQL-style paging: appends {@code LIMIT cnt OFFSET start} and returns rows as maps.
	 *
	 * @return the page rows, or {@code null} if the query failed
	 */
	public List<Map<String, Object>> split(String sql, int start, int cnt, Object... obj) {
		return getList(appendLimit(sql, start, cnt), obj);
	}

	/**
	 * MySQL-style paging: appends {@code LIMIT cnt OFFSET start} and returns the first
	 * column of each row converted to {@code _class}.
	 *
	 * @return the page values, or {@code null} if the query failed
	 */
	public <T> List<T> split(String sql, Class<T> _class, int start, int cnt, Object... obj) {
		return getList(appendLimit(sql, start, cnt), _class, obj);
	}

	/** Appends the LIMIT / OFFSET clause used by both {@code split} overloads. */
	private static String appendLimit(String sql, int start, int cnt) {
		return sql + " LIMIT " + cnt + " OFFSET " + start;
	}

	/** Logs the SQL and its bind values at INFO level; null-safe for {@code args}. */
	private void logSql(String sql, Object[] args) {
		if (logger.isInfoEnabled()) {
			logger.info(sql + " - " + Arrays.toString(args));
		}
	}

	/**
	 * Picks the inserted row's id from the generated-key map.
	 *
	 * @return the {@code "id"} key, or the sole key when {@code "id"} is absent,
	 *         narrowed to int; 0 when none
	 * @throws ClassCastException if the chosen key is not numeric
	 */
	private static int extractGeneratedId(Map<String, Object> keys) {
		if (keys == null || keys.isEmpty()) {
			return 0;
		}
		Object id = keys.get(ID_COLUMN);
		if (id == null && keys.size() == 1) {
			id = keys.values().iterator().next();
		}
		// Number, not Integer: drivers commonly return Long/BigInteger for generated keys.
		return id != null ? ((Number) id).intValue() : 0;
	}

	/** Returns {@code obj} itself when it is a Map, otherwise its bean properties as a Map. */
	@SuppressWarnings("unchecked")
	private static Map<String, Object> toColumnMap(Object obj) {
		if (obj instanceof Map<?, ?>) {
			return (Map<String, Object>) obj;
		}
		return ReflectUtil.beanToMap(obj);
	}

	/**
	 * Builds {@code INSERT INTO table(c1, c2)VALUES( ?, ?)} and collects the values as
	 * parameters, in the map's iteration order.
	 */
	private Query getInsertBody(Map<String, Object> map, String _tableName) {
		Query insert = new DBQuery(new StringBuffer("INSERT INTO " + _tableName + "("));
		StringBuilder fields = new StringBuilder();
		StringBuilder values = new StringBuilder("VALUES( ");
		for (Map.Entry<String, Object> entry : map.entrySet()) {
			if (fields.length() > 0) {
				fields.append(", ");
				values.append(", ");
			}
			fields.append(entry.getKey());
			values.append('?');
			insert.setParams(entry.getValue());
		}
		fields.append(')');
		values.append(')');
		insert.add(fields.toString() + values.toString());
		return insert;
	}

	/**
	 * Builds {@code UPDATE table SET c1 =?, c2 =? WHERE <where>} and collects the values
	 * as parameters, in the map's iteration order.
	 */
	private Query getUpdateBody(Map<String, Object> _map, String _tableName, String _where) {
		Query update = new DBQuery(new StringBuffer("UPDATE " + _tableName + " SET "));
		StringBuilder fields = new StringBuilder();
		for (Map.Entry<String, Object> entry : _map.entrySet()) {
			if (fields.length() > 0) {
				fields.append(", ");
			}
			fields.append(entry.getKey()).append(" =?");
			update.setParams(entry.getValue());
		}
		update.add(fields.toString() + " WHERE " + _where);
		return update;
	}

}